{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the library\n",
    "%pip install pythae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.datasets as datasets\n",
    "\n",
    "# Use the GPU when torch can see one, otherwise stay on the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "# Enable the autoreload extension in mode 2\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The last N_EVAL images of the MNIST training split are held out for evaluation\n",
    "N_EVAL = 10000\n",
    "\n",
    "mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)\n",
    "\n",
    "\n",
    "def to_unit_images(raw_pixels):\n",
    "    \"\"\"Reshape raw pixel values to (N, 1, 28, 28) and divide them by 255.\"\"\"\n",
    "    return raw_pixels.reshape(-1, 1, 28, 28) / 255.\n",
    "\n",
    "\n",
    "train_dataset = to_unit_images(mnist_trainset.data[:-N_EVAL])\n",
    "train_targets = mnist_trainset.targets[:-N_EVAL]\n",
    "eval_dataset = to_unit_images(mnist_trainset.data[-N_EVAL:])\n",
    "eval_targets = mnist_trainset.targets[-N_EVAL:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.models import PoincareVAE, PoincareVAEConfig\n",
    "from pythae.trainers import BaseTrainerConfig\n",
    "from pythae.pipelines.training import TrainingPipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's define some custom Encoder/Decoder to stick to the paper proposal\n",
    "import math\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from pythae.models.nn import BaseEncoder, BaseDecoder\n",
    "from pythae.models.base.base_utils import ModelOutput\n",
    "from pythae.models.pvae.pvae_utils import PoincareBall\n",
    "\n",
    "class RiemannianLayer(nn.Module):\n",
    "    \"\"\"Base layer whose weight and bias are expressed through a manifold.\n",
    "\n",
    "    Args:\n",
    "        in_features: size of the last dimension of the inputs.\n",
    "        out_features: number of output units.\n",
    "        manifold: object providing `expmap0` and `transp0` (a PoincareBall in this notebook).\n",
    "        over_param: if True, `_bias` is used directly as the bias; otherwise the bias\n",
    "            is rebuilt as `expmap0(_weight * _bias)`.\n",
    "        weight_norm: flag stored on the layer for subclasses to use.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_features, out_features, manifold, over_param, weight_norm):\n",
    "        super(RiemannianLayer, self).__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.manifold = manifold\n",
    "        # torch.empty allocates without initialising; reset_parameters() fills both tensors below\n",
    "        self._weight = nn.Parameter(torch.empty(out_features, in_features))\n",
    "        self.over_param = over_param\n",
    "        self.weight_norm = weight_norm\n",
    "        self._bias = nn.Parameter(torch.empty(out_features, 1))\n",
    "        self.reset_parameters()\n",
    "\n",
    "    @property\n",
    "    def weight(self):\n",
    "        \"\"\"Raw weight passed through `manifold.transp0` together with the current bias.\"\"\"\n",
    "        return self.manifold.transp0(self.bias, self._weight) # weight \\in T_0 => weight \\in T_bias\n",
    "\n",
    "    @property\n",
    "    def bias(self):\n",
    "        \"\"\"Bias point: the raw parameter if `over_param`, else `expmap0(_weight * _bias)`.\"\"\"\n",
    "        if self.over_param:\n",
    "            return self._bias\n",
    "        else:\n",
    "            return self.manifold.expmap0(self._weight * self._bias) # reparameterisation of a point on the manifold\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Re-draw `_weight` (Kaiming normal) and `_bias` (uniform within 4 / sqrt(fan_in)).\"\"\"\n",
    "        nn.init.kaiming_normal_(self._weight, a=math.sqrt(5))\n",
    "        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._weight)\n",
    "        bound = 4 / math.sqrt(fan_in)\n",
    "        nn.init.uniform_(self._bias, -bound, bound)\n",
    "        if self.over_param:\n",
    "            # Replace the drawn bias by its image under expmap0, in place and outside autograd\n",
    "            with torch.no_grad():\n",
    "                self._bias.set_(self.manifold.expmap0(self._bias))\n",
    "\n",
    "class GeodesicLayer(RiemannianLayer):\n",
    "    \"\"\"Layer whose output is `manifold.normdist2plane` of the input w.r.t. (bias, weight).\"\"\"\n",
    "    def __init__(self, in_features, out_features, manifold, over_param=False, weight_norm=False):\n",
    "        super(GeodesicLayer, self).__init__(in_features, out_features, manifold, over_param, weight_norm)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Evaluate the signed `normdist2plane` of `input` against every output unit.\n",
    "\n",
    "        The expand below is written for a 2-D `input` of shape (batch, in_features).\n",
    "        \"\"\"\n",
    "        input = input.unsqueeze(0)\n",
    "        # (1, batch, in) -> (1, batch, out, in): the same point is presented to each output unit\n",
    "        input = input.unsqueeze(-2).expand(*input.shape[:-(len(input.shape) - 2)], self.out_features, self.in_features)\n",
    "        res = self.manifold.normdist2plane(input, self.bias, self.weight,\n",
    "                                               signed=True, norm=self.weight_norm)\n",
    "        return res\n",
    "\n",
    "### Define paper encoder network\n",
    "class Encoder(BaseEncoder):\n",
    "    \"\"\" Usual encoder followed by an exponential map\n",
    "\n",
    "    Args:\n",
    "        model_config: configuration providing `input_dim`, `latent_dim` and `curvature`.\n",
    "        prior_iso: if True, the second head outputs a single value instead of `latent_dim` values.\n",
    "    \"\"\"\n",
    "    def __init__(self, model_config, prior_iso=False):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature)\n",
    "        self.enc = nn.Sequential(\n",
    "            # int(...) hands nn.Linear a plain Python int rather than a NumPy integer\n",
    "            nn.Linear(int(np.prod(model_config.input_dim)), 600), nn.ReLU(),\n",
    "        )\n",
    "        self.fc21 = nn.Linear(600, model_config.latent_dim)\n",
    "        self.fc22 = nn.Linear(600, model_config.latent_dim if not prior_iso else 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Flatten `x`, encode it and return the embedding plus the log-scale output.\"\"\"\n",
    "        e = self.enc(x.reshape(x.shape[0], -1))\n",
    "        mu = self.fc21(e)\n",
    "        mu = self.manifold.expmap0(mu)\n",
    "        # Both output keys carry the same value, so the fc22 head is evaluated only once\n",
    "        log_scale = torch.log(F.softplus(self.fc22(e)) + 1e-5)\n",
    "        return ModelOutput(\n",
    "            embedding=mu,\n",
    "            log_covariance=log_scale, # expects log_covariance\n",
    "            log_concentration=log_scale # for Riemannian Normal\n",
    "\n",
    "        )\n",
    "\n",
    "### Define paper decoder network\n",
    "class Decoder(BaseDecoder):\n",
    "    \"\"\" First layer is a Hypergyroplane followed by usual decoder\n",
    "\n",
    "    Args:\n",
    "        model_config: configuration providing `input_dim`, `latent_dim` and `curvature`.\n",
    "    \"\"\"\n",
    "    def __init__(self, model_config):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature)\n",
    "        self.input_dim = model_config.input_dim\n",
    "        self.dec = nn.Sequential(\n",
    "            GeodesicLayer(model_config.latent_dim, 600, self.manifold),\n",
    "            nn.ReLU(),\n",
    "            # int(...) hands nn.Linear a plain Python int rather than a NumPy integer\n",
    "            nn.Linear(600, int(np.prod(model_config.input_dim))),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, z):\n",
    "        \"\"\"Decode `z` and reshape the result to (batch,) + input_dim.\"\"\"\n",
    "        out = self.dec(z).reshape((z.shape[0],) + self.input_dim)  # reshape data\n",
    "        return ModelOutput(\n",
    "            reconstruction=out\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainer settings\n",
    "config = BaseTrainerConfig(\n",
    "    output_dir='my_model',\n",
    "    learning_rate=5e-4,\n",
    "    per_device_train_batch_size=64,\n",
    "    per_device_eval_batch_size=64,\n",
    "    num_epochs=10, # Change this to train the model a bit more\n",
    ")\n",
    "\n",
    "# Model settings\n",
    "model_config = PoincareVAEConfig(\n",
    "    input_dim=(1, 28, 28),\n",
    "    latent_dim=2,\n",
    "    reconstruction_loss=\"bce\",\n",
    "    prior_distribution=\"riemannian_normal\",\n",
    "    posterior_distribution=\"wrapped_normal\",\n",
    "    curvature=0.7\n",
    ")\n",
    "\n",
    "# Build the custom networks first (encoder, then decoder) and hand them to the model\n",
    "encoder = Encoder(model_config)\n",
    "decoder = Decoder(model_config)\n",
    "model = PoincareVAE(model_config=model_config, encoder=encoder, decoder=decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the model and the trainer configuration into a training pipeline\n",
    "pipeline = TrainingPipeline(training_config=config, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the pipeline on the train / eval tensors prepared earlier\n",
    "pipeline(train_data=train_dataset, eval_data=eval_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pythae.models import AutoModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the entry of 'my_model' that sorts last and load its 'final_model' folder\n",
    "run_dirs = sorted(os.listdir('my_model'))\n",
    "last_training = run_dirs[-1]\n",
    "final_model_dir = os.path.join('my_model', last_training, 'final_model')\n",
    "trained_model = AutoModel.load_from_folder(final_model_dir).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize latent space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "colors = sns.color_palette('pastel')\n",
    "\n",
    "fig = plt.figure(figsize=(10,8))\n",
    "\n",
    "label = eval_targets\n",
    "\n",
    "torch.manual_seed(42)\n",
    "# Encode the whole evaluation set; gradients are not needed for plotting\n",
    "with torch.no_grad():\n",
    "    mu = trained_model.encoder(eval_dataset.to(device)).embedding.detach().cpu()\n",
    "plt.scatter(mu[:, 0], mu[:, 1], c=label, cmap=matplotlib.colors.ListedColormap(colors))\n",
    "plt.title('Encoder embeddings of the evaluation set, coloured by label')\n",
    "plt.xlabel('latent dimension 0')\n",
    "plt.ylabel('latent dimension 1')\n",
    "\n",
    "cb = plt.colorbar()\n",
    "# One tick per palette colour: linspace fixes the number of ticks explicitly, and\n",
    "# .max().item() gives a plain Python number instead of looping over the tensor with max()\n",
    "max_label = label.max().item()\n",
    "loc = np.linspace(0, max_label, len(colors), endpoint=False)\n",
    "cb.set_ticks(loc)\n",
    "cb.set_ticklabels([f'{i}' for i in range(len(colors))])\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pythae.samplers import PoincareDiskSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the PoincareDiskSampler for the trained model\n",
    "pvae_samper = PoincareDiskSampler(model=trained_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 25 samples, enough to fill a 5x5 grid\n",
    "gen_data = pvae_samper.sample(num_samples=25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the generated samples on a 5x5 grid\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "# axes.flat walks the grid row by row, so position k is row k // 5, column k % 5\n",
    "for k, ax in enumerate(axes.flat):\n",
    "    ax.imshow(gen_data[k].cpu().squeeze(0), cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ... the other samplers work the same way"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing reconstructions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruct the first 25 evaluation images and bring the result back to the CPU\n",
    "eval_batch = eval_dataset[:25].to(device)\n",
    "reconstructions = trained_model.reconstruct(eval_batch).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the reconstructions on a 5x5 grid\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "# axes.flat walks the grid row by row, so position k is row k // 5, column k % 5\n",
    "for k, ax in enumerate(axes.flat):\n",
    "    ax.imshow(reconstructions[k].cpu().squeeze(0), cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first 25 evaluation images on a 5x5 grid\n",
    "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n",
    "\n",
    "# axes.flat walks the grid row by row, so position k is row k // 5, column k % 5\n",
    "for k, ax in enumerate(axes.flat):\n",
    "    ax.imshow(eval_dataset[k].cpu().squeeze(0), cmap='gray')\n",
    "    ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing interpolations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interpolate from evaluation images 0-4 towards images 5-9 with granularity=10\n",
    "start_images = eval_dataset[:5].to(device)\n",
    "end_images = eval_dataset[5:10].to(device)\n",
    "interpolations = trained_model.interpolate(start_images, end_images, granularity=10).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the interpolations: one row per image pair, one column per interpolation step\n",
    "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n",
    "\n",
    "for i, row_axes in enumerate(axes):\n",
    "    for j, ax in enumerate(row_axes):\n",
    "        ax.imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n",
    "        ax.axis('off')\n",
    "plt.tight_layout(pad=0.)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3efa06c4da850a09a4898b773c7e91b0da3286dbbffa369a8099a14a8fa43098"
  },
  "kernelspec": {
   "display_name": "Python 3.8.11 64-bit ('pythae_dev': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
